/*
 * Written by Dawid Kurzyniec and released to the public domain, as explained
 * at http://creativecommons.org/licenses/publicdomain
 */

package edu.emory.mathcs.util.net.tunnel;

import java.io.*;
import java.net.*;
import java.util.*;

import edu.emory.mathcs.util.io.*;
import edu.emory.mathcs.util.collections.ints.*;
import edu.emory.mathcs.backport.java.util.concurrent.*;
import edu.emory.mathcs.util.concurrent.*;
import java.security.AccessController;
import java.security.PrivilegedAction;

/**
 * @author Dawid Kurzyniec
 * @version 1.0
 */
/**
 * A server socket that shares one underlying "tunnel" {@link ServerSocket}
 * among many logical server sockets, each identified by an integer port
 * number within that tunnel.
 * <p>
 * All {@code TunnelServerSocket}s bound to the same tunnel address, and
 * created with equal {@link TunnelFactory} instances, share one dispatcher.
 * The dispatcher accepts physical connections on the tunnel and reads a
 * handshake of three ints from each: {@link #PROTOCOL_MAGIC},
 * {@link #PROTOCOL_VERSION_1}, and the target port. It then queues the
 * connection on the logical socket bound to that port. The client receives
 * {@link #ACK} when {@link #accept()} picks the connection up. If the
 * connection is rejected, it receives one of the {@code NACK_*} codes and
 * the connection is closed.
 *
 * @author Dawid Kurzyniec
 * @version 1.0
 */
public abstract class TunnelServerSocket extends ServerSocket {

    /** First int of the client handshake; identifies the tunnel protocol. */
    static int PROTOCOL_MAGIC     = 0x145ad09A;
    /** Second int of the client handshake; the only supported version. */
    static int PROTOCOL_VERSION_1 = 0x00000001;

    // Reply codes written back to the client as a single int.
    static final int ACK = 1;
    static final int NACK_UNSUPPORTED_PROTOCOL = -1;
    static final int NACK_NO_SERVER = -2;
    static final int NACK_SERVER_CLOSED = -3;
    static final int NACK_QUEUE_FULL = -4;

    /** Backlog of the physical tunnel socket created by {@link PlainTunnelFactory}. */
    static final int DEFAULT_DISPATCHER_BACKLOG = 200;

    /**
     * Registry of live tunnels, keyed by {@link TunnelKey}. A value is either
     * a {@link Dispatcher}, or a {@link Promise} while a dispatcher is being
     * created. Guarded by its own monitor.
     */
    final static Map dispatchers = new HashMap();

    /** Creates the physical server socket that carries a tunnel. */
    public static interface TunnelFactory {
        ServerSocket createTunnel(SocketAddress addr) throws IOException;
    }

    volatile boolean isBinding;
    volatile boolean closed;
    boolean bound;
    Dispatcher dispatcher;
    int localPort;
    // volatile: written by the setters, read without locking by accept()
    volatile int soTimeout = 0;
    volatile int receiveBufferSize = -1;

    volatile BlockingQueue connectionQueue;
    final TunnelFactory tunnelFactory;
    ServerSocket tunnel; // set when bound

    private final static TunnelFactory plainTunnelFactory = new PlainTunnelFactory();

    /** Creates an unbound socket that will use the given tunnel factory. */
    public TunnelServerSocket(TunnelFactory tunnelFactory) throws IOException {
        super();
        this.tunnelFactory = tunnelFactory;
    }

    /** Creates a socket bound to {@code addr} with a backlog of 50. */
    public TunnelServerSocket(TunnelFactory tunnelFactory,
                              TunnelSocketAddress addr) throws IOException {
        this(tunnelFactory, addr, 50);
    }

    /** Creates a socket bound to {@code addr} with the given backlog. */
    public TunnelServerSocket(TunnelFactory tunnelFactory,
                              TunnelSocketAddress addr, int backlog)
        throws IOException
    {
        this(tunnelFactory);
        bind(addr, backlog);
    }

    /** Creates an unbound socket using the default {@link PlainTunnelFactory}. */
    public TunnelServerSocket() throws IOException {
        this(plainTunnelFactory);
    }

    /**
     * Creates a socket bound to {@code port} inside the plain TCP tunnel
     * at {@code tunnelAddr}.
     */
    public TunnelServerSocket(InetSocketAddress tunnelAddr, int port) throws IOException {
        this(plainTunnelFactory, new TunnelSocketAddress(tunnelAddr, port));
    }

    /**
     * Binds this socket to a port inside a tunnel.
     *
     * @param endpoint a {@link TunnelSocketAddress}, or {@code null} for port 0
     *        in a tunnel at the wildcard address and port 0
     * @param backlog maximum number of connections queued for this socket
     * @throws SocketException if closed, already bound, or being bound
     * @throws IllegalArgumentException if backlog is not positive or the
     *         address is not a {@link TunnelSocketAddress}
     * @throws IOException if the tunnel cannot be created or the port is in use
     */
    public void bind(SocketAddress endpoint, int backlog) throws IOException {
        synchronized (this) {
            ensureNotClosed();
            if (isBound()) throw new SocketException("Already bound");
            if (isBinding) throw new SocketException("Already binding");
            if (endpoint == null) {
                endpoint = new TunnelSocketAddress(new InetSocketAddress(0), 0);
            }
            if (backlog <= 0) throw new IllegalArgumentException("Backlog must be positive");
            if (!(endpoint instanceof TunnelSocketAddress))
                throw new IllegalArgumentException("Unsupported address type");
            isBinding = true;
        }
        try {
            TunnelSocketAddress tep = (TunnelSocketAddress) endpoint;
            bind(this, tunnelFactory, tep, backlog);
        }
        finally {
            isBinding = false;
        }
        // close() may have run while we were binding. If it saw the socket
        // still unbound, it skipped unbind(), so release the port here.
        // A repeated unbind is a no-op: tryUnbindSocket checks identity.
        if (isClosed()) {
            unbind(this);
        }
    }

    /** Returns the physical tunnel socket, or null if not yet bound. */
    protected ServerSocket getTunnel() { return tunnel; }

    public boolean isClosed() {
        return closed;
    }

    public synchronized boolean isBound() {
        return bound;
    }

    // Invoked by the dispatcher while holding its own lock.
    public synchronized void setBound(Dispatcher dispatcher, int localPort, int backlog) {
        this.dispatcher = dispatcher;
        this.tunnel = dispatcher.ssock;
        this.localPort = localPort;
        this.connectionQueue = new ArrayBlockingQueue(backlog);
        this.bound = true;
    }

    // Invoked by the dispatcher. The connectionQueue variable is handed off
    // safely: this method can only run after setBound, and the dispatcher
    // subsequently obtains its own lock. Returns false if the queue is full.
    boolean connect(Socket s) {
        return connectionQueue.offer(s);
    }

    /** Returns the port within the tunnel, or -1 if not bound. */
    public synchronized int getLocalPort() {
        if (!isBound()) return -1;
        return localPort;
    }

    /** Returns the tunnel's local address plus this socket's port, or null if unbound. */
    public SocketAddress getLocalSocketAddress() {
        if (!isBound()) return null;
        return new TunnelSocketAddress(tunnel.getLocalSocketAddress(), getLocalPort());
    }

    /**
     * Waits for a queued connection, sends it {@link #ACK}, and returns it.
     * Connections whose ACK cannot be written (e.g. cancelled by the client)
     * are skipped.
     *
     * @throws SocketTimeoutException if SO_TIMEOUT is set and expires
     * @throws SocketException if the socket is unbound or gets closed
     * @throws InterruptedIOException if the waiting thread is interrupted
     */
    public Socket accept() throws IOException {
        ensureNotClosed();
        if (!isBound()) throw new SocketException("Socket is not bound yet");
        int timeout = this.soTimeout;
        while (true) {
            Socket req;
            try {
                if (timeout > 0) {
                    req = (Socket)connectionQueue.poll(timeout, TimeUnit.MILLISECONDS);
                }
                else {
                    req = (Socket)connectionQueue.take();
                }
            }
            catch (InterruptedException e) {
                throw new InterruptedIOException(e.toString());
            }

            if (req == null) {
                throw new SocketTimeoutException(
                    "Timeout on TunnelServerSocket.accept");
            }
            if (req == TERMINATOR) {
                // Put the marker back so that any other thread blocked in
                // accept() is woken up too.
                connectionQueue.offer(TERMINATOR);
                close();
                throw new SocketException("Socket closed");
            }

            try {
                ack(req);
                int rcvBuf = receiveBufferSize;
                if (rcvBuf > 0) {
                    req.setReceiveBufferSize(rcvBuf);
                }
                return req;
            }
            catch (IOException e) {
                // may have been cancelled by the client; take the next one
            }
        }
    }

    /** Queue marker that tells blocked accept() calls the socket was closed. */
    private static final Socket TERMINATOR = new Socket();

    /**
     * Closes this socket: releases its port in the tunnel, closes all pending
     * connections, and wakes any thread blocked in {@link #accept()}.
     * Subsequent calls do nothing.
     */
    public void close() throws IOException {
        synchronized (this) {
            if (closed) return;
            closed = true;
            if (!isBound()) return;
        }
        // Detach from the dispatcher first, so it stops routing new
        // connections here.
        unbind(this);
        // Cancel pending requests and abort a possibly blocked accept(). The
        // dispatcher may still deliver a connection it looked up earlier, so
        // if the queue refilled, drain again until the marker fits. A
        // blocking put() could otherwise hang forever.
        do {
            cancelPendingRequests();
        } while (!connectionQueue.offer(TERMINATOR));
    }

    // Closes every connection currently waiting in the queue.
    private void cancelPendingRequests() {
        Socket req;
        while ((req = (Socket)connectionQueue.poll()) != null) {
            if (req != TERMINATOR) IOUtils.closeQuietly(req);
        }
    }

    public void setSoTimeout(int timeout) throws SocketException {
        if (timeout < 0) throw new IllegalArgumentException("Timeout must be non-negative");
        ensureNotClosed();
        soTimeout = timeout;
    }

    public int getSoTimeout() throws IOException {
        ensureNotClosed();
        return soTimeout;
    }

    /** Accepted for compatibility; has no effect on tunnel sockets. */
    public void setReuseAddress(boolean on) throws SocketException {
        ensureNotClosed();
        // no op
    }

    /** Always true for tunnel sockets. */
    public boolean getReuseAddress() throws SocketException {
        ensureNotClosed();
        return true;
    }

    public String toString() {
        if (!isBound()) return "TunnelServerSocket[unbound]";
        return "TunnelServerSocket[tunnel=" + tunnel.getLocalSocketAddress() +
               "; port=" + localPort + "]";
    }

    /** Sets the receive buffer size applied to each accepted connection. */
    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        if (size <= 0) throw new IllegalArgumentException();
        receiveBufferSize = size;
    }

    /**
     * Returns the size set via {@link #setReceiveBufferSize}. If none was
     * set, returns the tunnel socket's size, or -1 when unbound.
     */
    public synchronized int getReceiveBufferSize() throws SocketException {
        return (receiveBufferSize > 0) ? receiveBufferSize :
            dispatcher != null ? dispatcher.ssock.getReceiveBufferSize() : -1;
    }


    private void ensureNotClosed() throws SocketException {
        if (isClosed()) throw new SocketException("Socket is closed");
    }

    /**
     * Owns one physical tunnel socket. It accepts connections on it and
     * routes each one, by the port number in its handshake, to the bound
     * {@code TunnelServerSocket}. All mutable state is guarded by
     * {@code this}, except {@code closed}, which is volatile.
     */
    private static class Dispatcher {
        final TunnelFactory tunnelFactory;
        final SocketAddress requestedAddr;
        final ServerSocket ssock;
        final IntMap sockets = new IntRadkeHashMap();          // port -> TunnelServerSocket
        final IntSortedSet addresses = new IntIntervalSet(0, Integer.MAX_VALUE); // used ports
        volatile boolean closed = false;
        int lastDefaultPort = 0x8000;  // auto-assigned ports are searched from here upward

        Dispatcher(TunnelFactory tunnelFactory, SocketAddress requestedAddr)
            throws IOException
        {
            this.tunnelFactory = tunnelFactory;
            this.requestedAddr = requestedAddr;
            this.ssock = tunnelFactory.createTunnel(requestedAddr);
        }

        /** Runs acceptLoop() on a new thread in the root thread group, with privileges. */
        synchronized void start() {
            AccessController.doPrivileged(new PrivilegedAction() {
                public Object run() {
                    ThreadGroup tg = Thread.currentThread().getThreadGroup();
                    while (tg.getParent() != null) tg = tg.getParent();
                    Runnable runLoop = new Runnable() {
                        public void run() { acceptLoop(); }};
                    Thread t = new Thread(tg, runLoop, "tunnel at " +
                                          ssock.getLocalSocketAddress());
                    t.start();
                    return null;
                }
            });
        }

        /** Marks the dispatcher closed and closes the tunnel socket. */
        synchronized void stop() {
            closed = true;
            IOUtils.closeQuietly(ssock);
        }

        boolean isClosed() {
            return closed;
        }

        /**
         * Registers {@code s} at {@code port}; 0 means "pick a free port".
         * Returns false if the tunnel has been closed.
         *
         * @throws BindException if the port is taken or no port is free
         */
        synchronized boolean bindSocket(TunnelServerSocket s, int port,
                                        int backlog) throws IOException {
            if (closed)
                return false;
            if (port != 0) {
                if (sockets.get(port) != null) {
                    throw new BindException("Address in use: " + ssock + ":" + port);
                }
            }
            else {
                port = findFreePort();
            }

            // have a free slot
            addresses.add(port);
            s.setBound(this, port, backlog);
            sockets.put(port, s);
            return true;
        }

        // Caller holds the lock. Picks the next unused port above the last
        // one handed out, wrapping around. Never returns 0, because 0 means
        // "any port" to bind().
        private int findFreePort() throws BindException {
            IntSortedSet unusedPorts = (IntSortedSet) addresses.complementSet();
            int port;
            try {
                port = unusedPorts.higher(lastDefaultPort);
            }
            catch (NoSuchElementException e) {
                // wraparound
                try {
                    port = unusedPorts.higher(0);
                }
                catch (NoSuchElementException e2) {
                    throw new BindException("No free ports in tunnel: " + ssock);
                }
            }
            lastDefaultPort = port;
            return port;
        }

        /**
         * Removes {@code s} if it is still registered at its port. Returns
         * true if that left the tunnel empty, in which case the dispatcher
         * has been stopped.
         */
        synchronized boolean tryUnbindSocket(TunnelServerSocket s) {
            int port = s.getLocalPort();
            TunnelServerSocket current = (TunnelServerSocket) sockets.get(port);
            if (current == s) {
                sockets.remove(port);
                addresses.remove(port);
                if (addresses.isEmpty()) {
                    stop();
                    return true;
                }
            }
            return false;
        }

        // Accepts until the tunnel socket is closed. NOTE(review): the
        // handshake in dispatch() is read synchronously on this thread. Unless
        // accepted sockets carry a read timeout, a client that connects but
        // never sends the handshake stalls every other connection on this
        // tunnel.
        private void acceptLoop() {
            while (!ssock.isClosed()) {
                try {
                    Socket s = ssock.accept();
                    dispatch(s);
                }
                catch (IOException e) {}
            }
            closed = true;
        }

        // Reads the handshake and either queues the connection on the target
        // socket or replies with a NACK code and closes it.
        private void dispatch(Socket s) {
            try {
                DataInputStream in = new DataInputStream(s.getInputStream());
                DataOutputStream out = new DataOutputStream(s.getOutputStream());
                int magic = in.readInt();
                if (magic != PROTOCOL_MAGIC) {
                    nack(s, out, NACK_UNSUPPORTED_PROTOCOL);
                    return;
                }
                int pver = in.readInt();
                if (pver != PROTOCOL_VERSION_1) {
                    nack(s, out, NACK_UNSUPPORTED_PROTOCOL);
                    return;
                }
                int port = in.readInt();
                TunnelServerSocket tss;
                synchronized (this) {
                    tss = (TunnelServerSocket) sockets.get(port);
                }
                if (tss == null) {
                    nack(s, out, NACK_NO_SERVER);
                    return;
                }
                if (tss.isClosed()) {
                    // closing but not yet unbound; don't leak the connection
                    // into a queue nobody will accept from
                    nack(s, out, NACK_SERVER_CLOSED);
                    return;
                }
                boolean enqueued = tss.connect(s);
                if (!enqueued) {
                    nack(s, out, NACK_QUEUE_FULL);
                    return;
                }
            }
            catch (IOException e) {
                IOUtils.closeQuietly(s);
                return;
            }
        }

    }

    /** Placeholder for a dispatcher being created; completed with it, or null on failure. */
    private static class Promise extends AsyncTask {
        Promise() {}
        void set(Object result) { super.setCompleted(result); }
    }

    /** Writes the rejection {@code code} to the client (best effort) and closes it. */
    private static void nack(Socket s, DataOutputStream out, int code) {
        try {
            out.writeInt(code);
            out.flush();
        }
        catch (IOException e) {
            // best effort: the connection is closed below regardless
        }
        finally {
            IOUtils.closeQuietly(s);
        }
    }

    /** Writes {@link #ACK} to the client; closes the socket on failure. */
    private static void ack(Socket s) throws IOException {
        DataOutputStream out = new DataOutputStream(s.getOutputStream());
        try {
            out.writeInt(ACK);
            out.flush();
        }
        catch (IOException e) {
            IOUtils.closeQuietly(s);
            throw e;
        }
    }

    /**
     * Binds {@code s} to {@code addr}. Reuses the live dispatcher for the
     * tunnel, or creates one. Exactly one thread creates a dispatcher for a
     * given key; concurrent binders wait on its {@link Promise}.
     */
    static void bind(TunnelServerSocket s, TunnelFactory tunnelFactory,
                     TunnelSocketAddress addr, int backlog)
        throws IOException
    {
        SocketAddress tunnelAddr = addr.getTunnelAddress();
        int port = addr.getPort();
        TunnelKey key = new TunnelKey(tunnelFactory, tunnelAddr);
        while (true) {
            // Fresh per attempt: a dispatcher found closed in a previous
            // iteration must not be retried.
            Dispatcher d = null;
            Promise p = null;
            boolean goahead = false;
            synchronized (dispatchers) {
                Object obj = dispatchers.get(key);
                if (obj instanceof Dispatcher && !((Dispatcher)obj).isClosed()) {
                    d = (Dispatcher)obj;
                }
                else if (obj instanceof Promise) {
                    p = (Promise)obj;
                }
                else {
                    // no entry, or a closed dispatcher: safe to overwrite
                    p = new Promise();
                    goahead = true;
                    dispatchers.put(key, p);
                }
            }

            if (d == null && !goahead) {
                // another thread is creating the dispatcher; wait for it
                try {
                    d = (Dispatcher) p.get();
                }
                catch (InterruptedException e) {
                    throw new InterruptedIOException();
                }
                catch (ExecutionException e) {
                    throw new RuntimeException("Cannot happen");
                }
                if (d == null) continue; // its creation failed; try again
            }

            if (d != null) {
                if (d.bindSocket(s, port, backlog)) return;
                continue; // try again; has been closed in the meantime
            }

            // goahead: this thread creates the dispatcher
            Dispatcher created = null;
            boolean published = false;
            try {
                created = new Dispatcher(tunnelFactory, tunnelAddr);
                created.bindSocket(s, port, backlog);
                synchronized (dispatchers) {
                    dispatchers.put(key, created);
                }
                created.start();
                published = true;
            }
            finally {
                if (!published) {
                    // Withdraw our entry so later binds can retry, and release
                    // the resources we acquired.
                    synchronized (dispatchers) {
                        Object cur = dispatchers.get(key);
                        if (cur == p || cur == created) dispatchers.remove(key);
                    }
                    if (created != null) created.stop();
                }
                // Always complete the promise so waiters never block forever;
                // null tells them to retry.
                p.set(published ? created : null);
            }
            return;
        }
    }

    /**
     * Removes {@code s} from its dispatcher. If that was the last socket, the
     * dispatcher has been stopped and is deregistered here.
     * PRE: {@code s} is bound.
     */
    static void unbind(TunnelServerSocket s) {
        Dispatcher d = s.dispatcher;
        if (!d.tryUnbindSocket(s)) return;
        // Dispatcher has been stopped; clean up. The entry may already be a
        // Promise or a replacement dispatcher installed by a concurrent bind,
        // so compare identities rather than casting.
        TunnelKey key = new TunnelKey(d.tunnelFactory, d.requestedAddr);
        synchronized (dispatchers) {
            if (dispatchers.get(key) == d) {
                dispatchers.remove(key);
            }
        }
    }

    /**
     * Creates tunnels as plain TCP {@link ServerSocket}s. Factories with the
     * same default port and backlog are equal, so they share tunnels.
     */
    public final static class PlainTunnelFactory implements TunnelFactory {
        final int defaultTunnelPort;   // used when the requested address is null
        final int dispatcherBacklog;   // backlog of the physical server socket

        PlainTunnelFactory() {
            this(0);
        }

        PlainTunnelFactory(int defaultTunnelPort) {
            this(defaultTunnelPort, DEFAULT_DISPATCHER_BACKLOG);
        }

        PlainTunnelFactory(int defaultTunnelPort, int dispatcherBacklog) {
            this.defaultTunnelPort = defaultTunnelPort;
            this.dispatcherBacklog = dispatcherBacklog;
        }

        /**
         * Opens a server socket on {@code addr}, which must be an
         * {@link InetSocketAddress}. A null {@code addr} means the wildcard
         * address on the default tunnel port.
         */
        public ServerSocket createTunnel(SocketAddress addr) throws IOException {
            if (addr == null) {
                addr = new InetSocketAddress((InetAddress)null, defaultTunnelPort);
            }
            else if (!(addr instanceof InetSocketAddress)) {
                throw new IllegalArgumentException("Unsupported address type");
            }
            InetSocketAddress iaddr = (InetSocketAddress)addr;
            return new ServerSocket(iaddr.getPort(), dispatcherBacklog, iaddr.getAddress());
        }

        public int hashCode() {
            return defaultTunnelPort ^ dispatcherBacklog;
        }

        public boolean equals(Object other) {
            if (other == this) { return true; }
            if (! (other instanceof PlainTunnelFactory)) return false;
            PlainTunnelFactory that = (PlainTunnelFactory) other;
            return (this.defaultTunnelPort == that.defaultTunnelPort &&
                    this.dispatcherBacklog == that.dispatcherBacklog);
        }

        public String toString() {
            return "PlainTunnelFactory(" + defaultTunnelPort + ")" ;
        }
    }

    /** Map key identifying a tunnel: (factory, requested address); address may be null. */
    private static class TunnelKey {
        final TunnelFactory factory;
        final SocketAddress addr;
        TunnelKey(TunnelFactory factory, SocketAddress addr) {
            this.factory = factory;
            this.addr = addr;
        }
        public int hashCode() {
            return factory.hashCode() ^ (addr == null ? 0 : addr.hashCode());
        }
        public boolean equals(Object other) {
            if (other == this) return true;
            if (! (other instanceof TunnelKey)) return false;
            TunnelKey that = (TunnelKey) other;
            return (this.factory.equals(that.factory) &&
                    (this.addr == null ? that.addr == null : this.addr.equals(that.addr)));
        }
        public String toString() {
            return "[" + factory + ": " + addr + "]";
        }
    }
}

